# Standard Library Imports
import argparse
import os
import sys

# Third-Party Imports
import numpy as np
import torch
from PIL import Image
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm

# Project-Specific Imports
import CLIPAD
from dataset import *
from stable_diffusion import CustomStableDiffusionInpaintPipeline
from attn import *
from core import *
from utils import *


def _select_device(device_id):
    """Restrict this process to GPU ``device_id`` and return the torch device for it.

    ``CUDA_VISIBLE_DEVICES`` is only honoured if it is set before CUDA is
    initialized. When it is honoured, the selected GPU becomes the *only*
    visible device and is addressed as ``cuda:0``; ``cuda:{device_id}`` would
    then be an invalid ordinal for any ``device_id != 0``.
    """
    was_initialized = torch.cuda.is_initialized()
    previous_mask = os.environ.get("CUDA_VISIBLE_DEVICES")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    if not was_initialized or previous_mask == str(device_id):
        # The mask is in effect now (or already was, e.g. from an earlier call
        # to ``main``), so the requested GPU is the single visible device 0.
        return torch.device("cuda:0")
    # CUDA was initialized under a different mask, so setting it now has no
    # effect in this process; keep addressing the GPU by its index.
    return torch.device(f"cuda:{device_id}")


def _build_memory_banks(args, model, pipe, mask_image, object_name, device):
    """Build the CLIP and diffusion memory banks from the ``args.shot`` training dataset.

    Returns:
        ``(clip_memory_bank, diff_memory_bank)`` as produced by
        ``build_clip_memory_bank`` / ``build_diff_memory_bank``.
    """
    train_dataset = CustomDataset(args.dataset, args.object, args.data_path, args.shot)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, num_workers=64, shuffle=False)

    clip_memory_bank = build_clip_memory_bank(
        dataloader=train_dataloader,
        model=model,
        device=device,
    )
    diff_memory_bank = build_diff_memory_bank(
        dataloader=train_dataloader,
        pipe=pipe,
        mask_image=mask_image,
        object=object_name,
        template=args.vision_template,
        timesteps=args.vision_timesteps,
        blocks=args.vision_blocks,
    )
    return clip_memory_bank, diff_memory_bank


def _load_ground_truth(mask_path, width, height):
    """Return the binarized ground-truth mask for one test image as an int tensor.

    An empty ``mask_path`` (presumably a defect-free sample with no mask file)
    yields an all-zero ``(height, width)`` mask. Otherwise the mask file is
    resized with ``resize_tensor`` and thresholded at 0.5.
    """
    if mask_path == "":
        # Same dtype as the thresholded branch, so all masks stack to one dtype.
        return torch.zeros((height, width), dtype=torch.int)
    with Image.open(mask_path) as mask:
        mask_tensor = to_tensor(mask).squeeze(0)
    gt = resize_tensor(mask_tensor, width, height)
    return (gt > 0.5).int()


def main(args):
    """Evaluate anomaly segmentation for the single category ``args.object``.

    Per test image, CLIP and diffusion "language" score maps are computed. When
    ``args.shot`` is truthy, memory-bank "vision" score maps are added as well.
    The two model families are fused as
    ``model_weight * clip + (1 - model_weight) * diff``.

    Returns:
        ``(auroc, aupro)``: ``auroc`` is ``roc_auc_score`` over every pixel of
        every test image; ``aupro`` is ``cal_pro_score`` on the same maps.
    """
    device = _select_device(args.device_id)

    # Translate the category id through ``object_dictionary`` when it has an
    # entry (presumably a prompt-friendly name); otherwise use it as-is.
    object_name = object_dictionary[args.object] if args.object in object_dictionary else args.object

    model, _, _ = CLIPAD.create_model_and_transforms(model_name='ViT-B-16-plus-240', pretrained='laion400m_e32', precision='fp32')
    model = model.to(device)

    cross_attn_init()
    pipe = CustomStableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
    ).to(device)
    pipe.unet = set_layer_with_name_and_path(pipe.unet)
    pipe.unet = register_cross_attention_hook(pipe.unet)

    # All-zero 512x512 inpainting mask (presumably: nothing is masked out).
    mask_image = to_pil_image(torch.zeros((512, 512)))

    clip_memory_bank = diff_memory_bank = None
    if args.shot:
        clip_memory_bank, diff_memory_bank = _build_memory_banks(
            args, model, pipe, mask_image, object_name, device
        )

    test_dataset = CustomDataset(args.dataset, args.object, args.data_path)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=64, shuffle=False)

    # Every score map and mask is resized to a size derived from the first test
    # image, so all of them can be stacked into one array at the end.
    first_image_path = next(iter(test_dataloader))[0][0]
    with Image.open(first_image_path) as first_image:
        new_width, new_height = get_new_size(first_image.size)

    total_scores, total_gts = [], []

    for image_path, mask_path in tqdm(test_dataloader):
        # batch_size=1: unwrap the single path from each collated batch.
        image_path = image_path[0]
        mask_path = mask_path[0]

        clip_language_score_map = get_clip_language_score_map(
            image_path=image_path,
            model=model,
            object=object_name,
            template=args.clip_template,
            device=device,
        )

        diff_language_score_map = get_diff_language_score_map(
            pipe=pipe,
            image_path=image_path,
            mask_image=mask_image,
            object=object_name,
            states=args.language_states,
            template=args.language_template,
            timesteps=args.language_timesteps,
            blocks=args.language_blocks,
        )

        if args.shot:
            clip_vision_score_map = get_clip_vision_score_map(
                image_path=image_path,
                model=model,
                memory_bank=clip_memory_bank,
                device=device,
            )

            diff_vision_score_map = get_diff_vision_score_map(
                pipe=pipe,
                image_path=image_path,
                mask_image=mask_image,
                object=object_name,
                template=args.vision_template,
                timesteps=args.vision_timesteps,
                blocks=args.vision_blocks,
                memory_bank=diff_memory_bank,
            )

            # Bring the diffusion language map onto the vision map's scale before
            # summing (exact rescaling is defined by ``adjust_scale``).
            diff_language_score_map = adjust_scale(diff_language_score_map, diff_vision_score_map)

            clip_score = clip_language_score_map + clip_vision_score_map
            diff_score = diff_language_score_map + diff_vision_score_map
        else:
            clip_score = clip_language_score_map
            diff_score = diff_language_score_map

        total_score = args.model_weight * clip_score + (1 - args.model_weight) * diff_score
        total_scores.append(resize_tensor(total_score, new_width, new_height).cpu())
        total_gts.append(_load_ground_truth(mask_path, new_width, new_height))

    score_px = np.stack(total_scores)
    gt_px = np.stack(total_gts)

    auroc = roc_auc_score(gt_px.ravel(), score_px.ravel())
    aupro = cal_pro_score(gt_px, score_px)

    return auroc, aupro


def get_args():
    """Parse the command-line options for the segmentation evaluation.

    ``--device_id``, ``--task``, ``--dataset`` and ``--model_weight`` are always
    consumed by the script: GPU selection, the log directory, and the CLIP /
    diffusion fusion weight. They are therefore required, so a missing flag
    fails immediately with a usage error instead of crashing after the models
    have been loaded. Omitting ``--shot`` leaves it ``None`` and selects the
    zero-shot (language-only) path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', type=int, required=True,
                        help='physical GPU index to run on')
    parser.add_argument('--seed', type=int, help='random seed passed to set_seed')
    parser.add_argument('--task', type=str, required=True,
                        help='task name, used as a log sub-directory')
    parser.add_argument('--dataset', type=str, choices=['mvtec', 'visa'], required=True)
    parser.add_argument('--data_path', type=str, help='root directory of the dataset')

    parser.add_argument('--shot', type=int,
                        help='number of reference images; omit for zero-shot')
    parser.add_argument('--model_weight', type=float, required=True,
                        help='weight of the CLIP score; diffusion gets 1 - weight')

    parser.add_argument('--clip_template', type=str)

    parser.add_argument('--language_template', type=str)
    parser.add_argument('--language_states', nargs='+', type=str)
    parser.add_argument('--language_timesteps', nargs='+', type=int)
    parser.add_argument('--language_blocks', nargs='+', type=str)

    parser.add_argument('--vision_template', type=str)
    parser.add_argument('--vision_timesteps', nargs='+', type=int)
    parser.add_argument('--vision_blocks', nargs='+', type=str)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    logger = setup_logger("test", base_dir=os.path.join("log", args.task, args.dataset), device_id=args.device_id)
    logger.info("Segmentation:\n%s", sys.argv)

    set_seed(args.seed)

    auroc_list = []
    aupro_list = []

    # Evaluate each category of the dataset in turn; ``main`` reads the current
    # category from ``args.object``. The loop variable is ``object_name`` so the
    # ``object`` builtin is not rebound at module scope.
    for object_name in get_objects(args.dataset):
        args.object = object_name
        auroc, aupro = main(args)

        logger.info("\nObject: %s, AUROC=%s, AUPRO=%s", object_name, auroc, aupro)

        auroc_list.append(auroc)
        aupro_list.append(aupro)

    # Unweighted mean over categories.
    logger.info("\nAUROC=%s, AUPRO=%s", np.mean(auroc_list), np.mean(aupro_list))
    # Repeat the command line so it sits next to the final numbers in the log.
    logger.info("Segmentation:\n%s", sys.argv)
